'''
Conway's Game of Life, simulated with NumPy and animated with matplotlib.
Date: 2019/4/8
Version: 0.2
'''

import numpy as np
import matplotlib.pyplot as plt 

class GameOfLife(object):
    '''Conway's Game of Life on a finite grid, animated with matplotlib.

    The board lives in ``self.cells`` with a one-cell ghost border on every
    side: the playable cells are ``cells[1:-1, 1:-1]`` and have shape
    ``(width, height)``. For ``bnd_cond='null'`` the border stays dead; for
    ``bnd_cond='period'`` it mirrors the opposite edge (a torus).
    '''

    def __init__(self, cells_shape=(20, 20), init_state='random', bnd_cond='null'):
        '''Create the board and an empty figure to draw it on.

        Args:
            cells_shape: ``(width, height)`` of the playable board. ``width``
                indexes axis 0 of the cell array, ``height`` axis 1.
            init_state: ``'random'`` fills the board with random 0/1 cells;
                ``'Glider'`` places one glider at ``cells[3:6, 3:6]`` (needs a
                board of at least 5x5). Any other value leaves the board empty.
            bnd_cond: ``'null'`` keeps a dead border; ``'period'`` wraps the
                edges periodically. Any other value behaves like ``'null'``.
        '''
        self.steps = 0  # number of generations computed so far
        # 3x3 neighborhood kernel: counts the 8 neighbors, not the cell itself
        self.mask = np.ones((3, 3))
        self.mask[1, 1] = 0
        # the size of the playable area
        self.width = cells_shape[0]
        self.height = cells_shape[1]
        # one extra row/column on each side holds the boundary (ghost) cells
        self.cells = np.zeros((self.width + 2, self.height + 2))

        self.init_state = init_state
        if self.init_state == 'random':
            self.cells[1:self.width + 1, 1:self.height + 1] = np.random.randint(
                2, size=(self.width, self.height))
        elif self.init_state == 'Glider':
            glider = np.zeros((3, 3))
            glider[0, 1] = glider[1, 2] = glider[2, 0] = glider[2, 1] = glider[2, 2] = 1
            self.cells[3:6, 3:6] = glider

        self.bnd_cond = bnd_cond
        self._apply_boundary()

        self.fig = plt.figure()
        self.ax1 = self.fig.add_subplot(111)
        self.ax1.set_aspect(1.0)
        # pcolor draws array axis 0 (width) vertically and axis 1 (height)
        # horizontally, so the limits follow that mapping; y is inverted so
        # row 0 is at the top.
        self.ax1.set_xlim([0, self.height])
        self.ax1.set_ylim([self.width, 0])
        self._mesh = None  # last drawn pcolor collection, replaced each frame

    def _apply_boundary(self):
        '''Refresh the ghost border of ``self.cells`` according to ``bnd_cond``.

        For ``'period'`` every ghost cell copies the real cell on the opposite
        side of the board. For any other value the border is left untouched
        (it is all zeros, i.e. dead).
        '''
        if self.bnd_cond != 'period':
            return
        c = self.cells
        # corners copy the diagonally opposite real cell
        c[0, 0] = c[-2, -2]
        c[0, -1] = c[-2, 1]
        c[-1, 0] = c[1, -2]
        c[-1, -1] = c[1, 1]
        # edges copy the opposite real row / column
        c[0, 1:-1] = c[-2, 1:-1]
        c[-1, 1:-1] = c[1, 1:-1]
        c[1:-1, 0] = c[1:-1, -2]
        c[1:-1, -1] = c[1:-1, 1]

    def update_state(self):
        '''Advance the board by one generation and refresh the border.

        Rule: a cell with exactly 3 live neighbors becomes alive, a cell with
        exactly 2 keeps its state, and every other cell dies.
        '''
        w, h = self.width, self.height
        # Neighbor count of every real cell: the sum of the 9 shifted views of
        # the padded board, weighted by self.mask (whose centre weight is 0).
        ngb_value = np.zeros((w, h))
        for di in range(3):
            for dj in range(3):
                ngb_value += self.mask[di, dj] * self.cells[di:di + w, dj:dj + h]
        current = self.cells[1:w + 1, 1:h + 1]
        temp_cells = np.zeros((w + 2, h + 2))
        temp_cells[1:w + 1, 1:h + 1] = np.where(
            ngb_value == 3, 1, np.where(ngb_value == 2, current, 0))
        self.cells = temp_cells
        self._apply_boundary()
        self.steps += 1

    def plot_state(self):
        '''Draw the current generation on ``self.ax1``, titled with the step count.'''
        self.ax1.set_title('Steps: {}'.format(self.steps))
        # with the hot cmap, 0 is black and 1 is white; invert so live cells
        # are drawn black
        draw_cells = 1 - self.cells[1:self.width + 1, 1:self.height + 1]
        # drop the previous frame so collections do not pile up on the axes
        old_mesh = getattr(self, '_mesh', None)
        if old_mesh is not None:
            old_mesh.remove()
        # fixed vmin/vmax: without them a uniform board (all dead or all
        # alive) collapses the colour range and is drawn entirely black
        self._mesh = self.ax1.pcolor(draw_cells, cmap='hot', edgecolor='k',
                                     vmin=0, vmax=1)

    def update_and_plot(self, steps, interval=0.2):
        '''Animate ``steps`` generations, pausing ``interval`` seconds per frame.

        The starting generation and every computed generation are drawn; the
        final one stays on screen via the blocking ``plt.show()``.
        '''
        for _ in range(steps):
            self.plot_state()
            plt.pause(interval)
            self.update_state()
        # draw the last computed generation too, otherwise it is never shown
        self.plot_state()
        plt.show()

if __name__ == '__main__':
    # Demo: a random 20x20 board with a dead border, animated for 10 steps.
    game1 = GameOfLife()
    game1.update_and_plot(10)

    # Alternative demo: a single glider on a periodic (toroidal) board.
    # game2 = GameOfLife(init_state='Glider', bnd_cond='period')
    # game2.update_and_plot(100)